/*******************************************************************************
 * Copyright 2020 huanggefan.cn
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/

package core

import (
	"context"
	"fmt"
	"net"
	"runtime"
	"strconv"
	"time"

	"gorm.io/gorm/logger"
)

// Log severity levels, ordered from least severe (LogDebug) to most severe
// (LogError). A Logger drops every message whose level is below its
// configured minimum display level.
const (
	LogDebug int = iota
	LogInfo
	LogWarning
	LogError
)

// Format strings for lines printed to stdout. Each one wraps the whole line
// in an ANSI color escape sequence (34 blue, 32 green, 33 yellow, 31 red)
// followed by a reset sequence. The two %s verbs receive the timestamp and
// the message, in that order.
const (
	logDebugPrefixWithColor   = "\033[34m%s [Debug  ]: %s\033[0m"
	logInfoPrefixWithColor    = "\033[32m%s [Info   ]: %s\033[0m"
	logWarningPrefixWithColor = "\033[33m%s [Warning]: %s\033[0m"
	logErrorPrefixWithColor   = "\033[31m%s [Error  ]: %s\033[0m"
)

// Uncolored counterparts of the stdout formats, used for the copies of each
// line that are queued for UDP delivery. The two %s verbs receive the
// timestamp and the message, in that order.
const (
	logDebugPrefix   = "%s [Debug  ]: %s"
	logInfoPrefix    = "%s [Info   ]: %s"
	logWarningPrefix = "%s [Warning]: %s"
	logErrorPrefix   = "%s [Error  ]: %s"
)

// Logger prints leveled log lines to stdout and, when created with
// NewUDPLogger, additionally queues an uncolored copy of every emitted line
// for delivery over UDP by a background goroutine.
type Logger struct {
	// minDisplayLevel is the lowest level that is still emitted; messages
	// with a lower level are dropped.
	minDisplayLevel int

	// UDP delivery settings. useUDP enables queuing of lines for UDP
	// delivery, udpAddr is the remote address given to NewUDPLogger, udpConn
	// is the dialed connection (nil when dialing failed, in which case queued
	// lines are discarded), and udpMaxByte is the maximum number of payload
	// bytes written per datagram.
	useUDP     bool
	udpAddr    *net.UDPAddr
	udpConn    *net.UDPConn
	udpMaxByte int

	// Per-level buffered queues of lines waiting to be sent over UDP. They
	// are only written to when useUDP is set.
	debugChan   chan string
	infoChan    chan string
	warningChan chan string
	errorChan   chan string
}

// NewLogger creates a Logger that writes to stdout only.
//
// level is the minimum severity that will be emitted; init clamps it into
// the range [LogDebug, LogError].
func NewLogger(level int) *Logger {
	// init parks a goroutine on this context. Cancelling it when the
	// constructor returns signals that initialization is complete.
	initCtx, initDone := context.WithCancel(context.Background())
	defer initDone()

	l := &Logger{}
	l.init(initCtx, level)
	return l
}

// NewUDPLogger creates a Logger that writes to stdout and additionally queues
// every emitted line for delivery to addr over UDP.
//
// level is the minimum severity that will be emitted. maxBytes caps the
// payload size of each datagram; a value <= 0 selects the default of 1472
// bytes (presumably a 1500-byte Ethernet MTU minus IPv4 and UDP headers).
//
// The Logger is returned even when dialing addr fails. In that case the
// returned error is non-nil and the connection is left nil, so queued lines
// are discarded instead of being sent.
func NewUDPLogger(level int, addr *net.UDPAddr, maxBytes int) (*Logger, error) {
	// init parks a goroutine on this context. Cancelling it when the
	// constructor returns signals that initialization is complete.
	initCtx, initDone := context.WithCancel(context.Background())
	defer initDone()

	l := &Logger{}
	l.init(initCtx, level)

	// init resets useUDP to false, so the UDP settings are applied after it.
	l.useUDP = true
	l.udpAddr = addr
	l.udpMaxByte = maxBytes
	if maxBytes <= 0 {
		l.udpMaxByte = 1472
	}

	var err error
	l.udpConn, err = net.DialUDP(l.udpAddr.Network(), nil, l.udpAddr)
	return l, err
}

// getMsgWithSource returns msg prefixed with "file:line " of the code that
// called the public logging method.
//
// The location is resolved with a fixed stack depth of 2, so this method
// must be called directly from the method whose caller should be reported.
func (logger *Logger) getMsgWithSource(msg string) string {
	_, file, line, _ := runtime.Caller(2)

	// Room for the path, the message, both separators and the line digits.
	out := make([]byte, 0, len(file)+len(msg)+12)
	out = append(out, file...)
	out = append(out, ':')
	out = strconv.AppendInt(out, int64(line), 10)
	out = append(out, ' ')
	out = append(out, msg...)
	return string(out)
}

// init puts the logger into its base, stdout-only state: it stores level
// clamped into [LogDebug, LogError], disables UDP delivery and allocates the
// four per-level queues.
//
// It also parks a goroutine that waits for ctx to be cancelled and only then
// starts goLog, so a constructor can finish configuring the logger before
// the drain loop inspects it.
func (logger *Logger) init(ctx context.Context, level int) {
	switch {
	case level < LogDebug:
		level = LogDebug
	case level > LogError:
		level = LogError
	}
	logger.minDisplayLevel = level
	logger.useUDP = false

	const queueSize = 10240
	logger.debugChan = make(chan string, queueSize)
	logger.infoChan = make(chan string, queueSize)
	logger.warningChan = make(chan string, queueSize)
	logger.errorChan = make(chan string, queueSize)

	startWhenReady := func() {
		<-ctx.Done() // initialization finished; the drain loop may start
		go logger.goLog()
	}
	go startWhenReady()
}

// Debug logs msg at LogDebug level, prefixed with the file and line of the
// caller.
func (logger *Logger) Debug(msg string) {
	// Skip the stack lookup when the message would be filtered out anyway.
	if LogDebug < logger.minDisplayLevel {
		return
	}
	// getMsgWithSource resolves the caller by a fixed stack depth, so it has
	// to be invoked directly from this method.
	logger.log(logger.getMsgWithSource(msg), LogDebug)
}

// Info logs msg at LogInfo level. Unlike Debug, Warning and Error, the
// message is logged as-is, without a source-location prefix.
func (logger *Logger) Info(msg string) {
	logger.log(msg, LogInfo)
}

// Warning logs msg at LogWarning level, prefixed with the file and line of
// the caller.
func (logger *Logger) Warning(msg string) {
	// Skip the stack lookup when the message would be filtered out anyway.
	if LogWarning < logger.minDisplayLevel {
		return
	}
	// getMsgWithSource resolves the caller by a fixed stack depth, so it has
	// to be invoked directly from this method.
	logger.log(logger.getMsgWithSource(msg), LogWarning)
}

// Error logs msg at LogError level, prefixed with the file and line of the
// caller.
func (logger *Logger) Error(msg string) {
	// getMsgWithSource resolves the caller by a fixed stack depth, so it has
	// to be invoked directly from this method.
	located := logger.getMsgWithSource(msg)
	logger.log(located, LogError)
}

// Printf formats msg with data according to the fmt.Sprintf rules and logs
// the result at LogInfo level.
//
// The signature matches the Printf-style writer that gorm.io/gorm/logger
// accepts (presumably the reason this method exists), so msg is treated as a
// format string rather than as literal text.
func (logger *Logger) Printf(msg string, data ...interface{}) {
	logger.log(fmt.Sprintf(msg, data...), LogInfo)
}

// log emits msg at the given level.
//
// Messages below the minimum display level are dropped. Otherwise the line
// is printed to stdout in its colored form and, when UDP delivery is
// enabled, an uncolored copy is pushed onto the queue for that level first.
// The push blocks while that queue is full. A level outside the four known
// ones prints an empty line and queues nothing.
func (logger *Logger) log(msg string, level int) {
	if level < logger.minDisplayLevel {
		return
	}

	var (
		colorFormat string
		plainFormat string
		queue       chan string
	)
	switch level {
	case LogDebug:
		colorFormat, plainFormat, queue = logDebugPrefixWithColor, logDebugPrefix, logger.debugChan
	case LogInfo:
		colorFormat, plainFormat, queue = logInfoPrefixWithColor, logInfoPrefix, logger.infoChan
	case LogWarning:
		colorFormat, plainFormat, queue = logWarningPrefixWithColor, logWarningPrefix, logger.warningChan
	case LogError:
		colorFormat, plainFormat, queue = logErrorPrefixWithColor, logErrorPrefix, logger.errorChan
	default:
		// Unknown level: nothing to format, only the bare newline is printed.
		fmt.Println()
		return
	}

	stamp := time.Now().Format(timeLayout)
	colored := fmt.Sprintf(colorFormat, stamp, msg)
	if logger.useUDP {
		queue <- fmt.Sprintf(plainFormat, stamp, msg)
	}
	fmt.Println(colored)
}

// goLog drains the per-level queues and hands each line to doUDPLog.
//
// It returns immediately when UDP delivery is disabled; otherwise it loops
// forever. On every iteration a pending error line is taken first; only
// when the error queue is empty does the loop block until a line arrives on
// any of the four queues.
func (logger *Logger) goLog() {
	if !logger.useUDP {
		return // no UDP delivery requested, nothing to drain
	}

	for {
		var entry string

		select {
		case entry = <-logger.errorChan:
		default:
			select {
			case entry = <-logger.errorChan:
			case entry = <-logger.warningChan:
			case entry = <-logger.infoChan:
			case entry = <-logger.debugChan:
			}
		}

		logger.doUDPLog(entry)
	}
}

// doUDPLog writes l to the UDP connection as a single datagram.
//
// Lines longer than udpMaxByte bytes are truncated. The cut is moved back by
// up to three bytes when it would otherwise fall inside a multi-byte UTF-8
// sequence, so the datagram never ends in a partial character. Nothing is
// sent when there is no connection.
func (logger *Logger) doUDPLog(l string) {
	if logger.udpConn == nil {
		return
	}

	payload := []byte(l)
	if limit := logger.udpMaxByte; len(payload) > limit {
		cut := limit
		// A byte of the form 10xxxxxx continues a multi-byte character. A
		// character has at most three such bytes, so at most three steps
		// back are needed to reach its first byte.
		for steps := 0; steps < 3 && cut > 0 && payload[cut]&0xC0 == 0x80; steps++ {
			cut--
		}
		if payload[cut]&0xC0 == 0x80 {
			// Not valid UTF-8 around the limit; fall back to a plain byte cut.
			cut = limit
		}
		payload = payload[:cut]
	}

	// Delivery is best-effort: a failed write must not disturb the program
	// that is doing the logging.
	_, _ = logger.udpConn.Write(payload)
}

// Print is the logging hook for GORM debug output. It emits one line at
// LogDebug level and does nothing when debug output is filtered out.
//
// Only calls whose first value is the string "sql" carry content: values[1]
// is then taken as the source location and values[3] as the SQL text
// (presumably the value layout of the GORM v1 logger; verify against the
// GORM version in use). Any other shape, including fewer than four values
// or values of an unexpected type, results in a line with an empty body
// instead of a panic.
func (logger *Logger) Print(values ...interface{}) {
	if LogDebug < logger.minDisplayLevel {
		return
	}

	// values[3] is read below, so at least four values are required.
	var source, sql string
	if len(values) >= 4 {
		if kind, ok := values[0].(string); ok && kind == "sql" {
			source, _ = values[1].(string)
			sql, _ = values[3].(string)
		}
	}

	now := time.Now().Format(timeLayout)
	body := source + " " + sql
	colored := fmt.Sprintf(logDebugPrefixWithColor, now, body)
	if logger.useUDP {
		logger.debugChan <- fmt.Sprintf(logDebugPrefix, now, body)
	}
	fmt.Println(colored)
}

// GetGormLogLevel translates FlagLogLevel into a GORM log level. Each
// application level maps to the next stricter GORM level: LogDebug to
// logger.Info, LogInfo to logger.Warn, LogWarning to logger.Error and
// LogError to logger.Silent. Any other value yields logger.Error.
func GetGormLogLevel() logger.LogLevel {
	gormLevel := logger.Error // fallback for unrecognized values

	switch FlagLogLevel {
	case LogDebug:
		gormLevel = logger.Info
	case LogInfo:
		gormLevel = logger.Warn
	case LogWarning:
		gormLevel = logger.Error
	case LogError:
		gormLevel = logger.Silent
	}

	return gormLevel
}
